{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Graph Parallelization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run sequentially"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Image, display\n",
    "from typing_extensions import TypedDict\n",
    "from langgraph.graph import StateGraph, START, END\n",
    "\n",
    "# define a graph state\n",
    "class State(TypedDict):\n",
    "    value: str\n",
    "\n",
    "def a(state: State):\n",
    "    print(f\"Adding 'A' to state {state['value']}\")\n",
    "    return {\"value\": [\"A\"]}\n",
    "\n",
    "def b(state: State):\n",
    "    print(f\"Adding 'B' to state {state['value']}\")\n",
    "    return {\"value\": [\"B\"]}\n",
    "\n",
    "def c(state: State):\n",
    "    print(f\"Adding 'C' to state {state['value']}\")\n",
    "    return {\"value\": [\"C\"]}\n",
    "\n",
    "def d(state: State):\n",
    "    print(f\"Adding 'D' to state {state['value']}\")\n",
    "    return {\"value\": [\"D\"]}\n",
    "\n",
    "\n",
    "\n",
    "builder = StateGraph(State)\n",
    "\n",
    "builder.add_node(a)\n",
    "builder.add_node(b)\n",
    "builder.add_node(c)\n",
    "builder.add_node(d)\n",
    "\n",
    "builder.add_edge(START, \"a\")\n",
    "builder.add_edge(\"a\", \"b\")\n",
    "builder.add_edge(\"b\", \"c\")\n",
    "builder.add_edge(\"c\", \"d\")\n",
    "builder.add_edge(\"d\", END)\n",
    "graph = builder.compile()\n",
    "\n",
    "display(Image(graph.get_graph().draw_mermaid_png()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.invoke({\"value\": []})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run in parallel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "builder = StateGraph(State)\n",
    "\n",
    "builder.add_node(a)\n",
    "builder.add_node(b)\n",
    "builder.add_node(c)\n",
    "builder.add_node(d)\n",
    "\n",
    "builder.add_edge(START, \"a\")\n",
    "builder.add_edge(\"a\", \"b\")\n",
    "builder.add_edge(\"a\", \"c\")\n",
    "builder.add_edge(\"b\", \"d\")\n",
    "builder.add_edge(\"c\", \"d\")\n",
    "builder.add_edge(\"d\", END)\n",
    "graph = builder.compile()\n",
    "\n",
    "display(Image(graph.get_graph().draw_mermaid_png()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.invoke({\"value\": []})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "let's redefine Graph's State"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import operator\n",
    "from typing import Annotated\n",
    "\n",
    "class State(TypedDict):\n",
    "    value: Annotated[list, operator.add]\n",
    "\n",
    "builder = StateGraph(State)\n",
    "\n",
    "builder.add_node(a)\n",
    "builder.add_node(b)\n",
    "builder.add_node(c)\n",
    "builder.add_node(d)\n",
    "\n",
    "builder.add_edge(START, \"a\")\n",
    "builder.add_edge(\"a\", \"b\")\n",
    "builder.add_edge(\"a\", \"c\")\n",
    "builder.add_edge(\"b\", \"d\")\n",
    "builder.add_edge(\"c\", \"d\")\n",
    "builder.add_edge(\"d\", END)\n",
    "graph = builder.compile()\n",
    "\n",
    "display(Image(graph.get_graph().draw_mermaid_png()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.invoke({\"value\": []})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "nodes \"B\" and \"C\" are executed concurrently in the same superstep, meaning \"B\" & \"C\" are in the same transactional context, so if one fails, both wont update the state."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Extend B route"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def b_1(state: State):\n",
    "    print(f\"Adding 'B_1' to state {state['value']}\")\n",
    "    return {\"value\": [\"B_1\"]}\n",
    "\n",
    "def b_2(state: State):\n",
    "    print(f\"Adding 'B_2' to state {state['value']}\")\n",
    "    return {\"value\": [\"B_2\"]}\n",
    "\n",
    "builder = StateGraph(State)\n",
    "\n",
    "builder.add_node(a)\n",
    "builder.add_node(b_1)\n",
    "builder.add_node(b_2)\n",
    "builder.add_node(c)\n",
    "builder.add_node(d)\n",
    "\n",
    "builder.add_edge(START, \"a\")\n",
    "builder.add_edge(\"a\", \"b_1\")\n",
    "builder.add_edge(\"a\", \"c\")\n",
    "builder.add_edge(\"b_1\", \"b_2\")\n",
    "builder.add_edge(\"b_2\", \"d\")\n",
    "builder.add_edge(\"c\", \"d\")\n",
    "builder.add_edge(\"d\", END)\n",
    "graph = builder.compile()\n",
    "\n",
    "display(Image(graph.get_graph().draw_mermaid_png()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.invoke({\"value\": []})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Only B_1 & C are executed concurrently in the same superstep.\n",
    "\n",
    "Let's force D to wait until B_1 + B_2 AND C are completed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "builder = StateGraph(State)\n",
    "\n",
    "builder.add_node(a)\n",
    "builder.add_node(b_1)\n",
    "builder.add_node(b_2)\n",
    "builder.add_node(c)\n",
    "builder.add_node(d)\n",
    "\n",
    "builder.add_edge(START, \"a\")\n",
    "builder.add_edge(\"a\", \"b_1\")\n",
    "builder.add_edge(\"a\", \"c\")\n",
    "builder.add_edge(\"b_1\", \"b_2\")\n",
    "builder.add_edge([\"b_2\", \"c\"], \"d\")\n",
    "builder.add_edge(\"d\", END)\n",
    "graph = builder.compile()\n",
    "\n",
    "display(Image(graph.get_graph().draw_mermaid_png()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.invoke({\"value\": []})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conditional Branching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import operator\n",
    "from typing import Annotated, Sequence\n",
    "\n",
    "from typing_extensions import TypedDict\n",
    "\n",
    "from langgraph.graph import StateGraph, START, END\n",
    "\n",
    "\n",
    "class State(TypedDict):\n",
    "    value: Annotated[list, operator.add]\n",
    "    route: str\n",
    "\n",
    "def a(state: State):\n",
    "    print(f\"Adding 'A' to state {state['value']}\")\n",
    "    return {\"value\": [\"A\"]}\n",
    "\n",
    "def b(state: State):\n",
    "    print(f\"Adding 'B' to state {state['value']}\")\n",
    "    return {\"value\": [\"B\"]}\n",
    "\n",
    "def c(state: State):\n",
    "    print(f\"Adding 'C' to state {state['value']}\")\n",
    "    return {\"value\": [\"C\"]}\n",
    "\n",
    "def d(state: State):\n",
    "    print(f\"Adding 'D' to state {state['value']}\")\n",
    "    return {\"value\": [\"D\"]}\n",
    "\n",
    "def e(state: State):\n",
    "    print(f\"Adding 'E' to state {state['value']}\")\n",
    "    return {\"value\": [\"E\"]}\n",
    "\n",
    "\n",
    "builder = StateGraph(State)\n",
    "builder.add_node(a)\n",
    "builder.add_node(b)\n",
    "builder.add_node(c)\n",
    "builder.add_node(d)\n",
    "builder.add_node(e)\n",
    "\n",
    "builder.add_edge(START, \"a\")\n",
    "\n",
    "\n",
    "def route_bc_or_cd(state: State) -> Sequence[str]:\n",
    "    if state[\"route\"] == \"bc\":\n",
    "        return [\"b\", \"c\"]\n",
    "    return [\"c\", \"d\"]\n",
    "    \n",
    "intermediates = [\"b\", \"c\", \"d\"]\n",
    "builder.add_conditional_edges(\n",
    "    \"a\",\n",
    "    route_bc_or_cd,\n",
    "    intermediates\n",
    ")\n",
    "for node in intermediates:\n",
    "    builder.add_edge(node, \"e\")\n",
    "\n",
    "builder.add_edge(\"e\", END)\n",
    "\n",
    "graph = builder.compile()\n",
    "\n",
    "display(Image(graph.get_graph().draw_mermaid_png()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.invoke({\"value\": [], \"route\": \"bc\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.invoke({\"value\": [], \"route\": \"cd\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Practical example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.messages import HumanMessage, SystemMessage\n",
    "\n",
    "from langchain_community.document_loaders import WikipediaLoader\n",
    "from langchain_community.tools import TavilySearchResults\n",
    "\n",
    "\n",
    "from langchain_openai import ChatOpenAI\n",
    "llm = ChatOpenAI(model=\"gpt-4o-mini\") \n",
    "\n",
    "class State(TypedDict):\n",
    "    question: str\n",
    "    answer: str\n",
    "    context: Annotated[list, operator.add]\n",
    "\n",
    "\n",
    "def search_web(state):\n",
    "    tavily_search = TavilySearchResults(max_results=5)\n",
    "    search_docs = tavily_search.invoke(state['question'])\n",
    "\n",
    "    results = [\n",
    "        f'<Document>\\n{doc[\"content\"]}\\n</Document>'\n",
    "        for doc in search_docs\n",
    "    ]\n",
    "\n",
    "    return {\"context\": results}\n",
    "\n",
    "\n",
    "def search_wikipedia(state):\n",
    "    search_docs = WikipediaLoader(query=state['question'], load_max_docs=5).load()\n",
    "\n",
    "    results = [\n",
    "        f'<Document>\\n{doc.page_content}\\n</Document>'\n",
    "        for doc in search_docs\n",
    "    ]\n",
    "    \n",
    "    return {\"context\": results} \n",
    "\n",
    "\n",
    "def generate_answer(state):\n",
    "    # System message\n",
    "    system_message = SystemMessage(content=(\"\"\"\n",
    "        You are an AI assistant that answers questions based on the provided documents.\n",
    "        Guidelines:\n",
    "            - Provide direct, concise, and accurate answers.\n",
    "            - When possible, cite the relevant document or URL.\n",
    "            - If multiple documents contain relevant information, synthesize the best answer.\n",
    "\n",
    "        If a document contains conflicting information, mention both perspectives.\n",
    "    \"\"\"))\n",
    "\n",
    "\n",
    "    formatted_docs = \"\\n\".join(\n",
    "        [f\"- {doc}\" for doc in state[\"context\"]]\n",
    "    )\n",
    "\n",
    "    system_context = SystemMessage(content=(f\"Use the following documents as context for your response:\\n\\n{formatted_docs}\"))\n",
    "\n",
    "    answer = llm.invoke([system_message] + [system_context] + [HumanMessage(content=state[\"question\"])])\n",
    "    \n",
    "    # Append it to state\n",
    "    return {\"answer\": answer}\n",
    "\n",
    "# Add nodes\n",
    "builder = StateGraph(State)\n",
    "\n",
    "# Initialize each node with node_secret \n",
    "builder.add_node(\"search_web\",search_web)\n",
    "builder.add_node(\"search_wikipedia\", search_wikipedia)\n",
    "builder.add_node(\"generate_answer\", generate_answer)\n",
    "\n",
    "# Flow\n",
    "builder.add_edge(START, \"search_wikipedia\")\n",
    "builder.add_edge(START, \"search_web\")\n",
    "builder.add_edge(\"search_wikipedia\", \"generate_answer\")\n",
    "builder.add_edge(\"search_web\", \"generate_answer\")\n",
    "builder.add_edge(\"generate_answer\", END)\n",
    "graph = builder.compile()\n",
    "\n",
    "display(Image(graph.get_graph().draw_mermaid_png()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result = graph.invoke({\"question\": \"should i invest in AI stocks now\"})\n",
    "print(result['answer'].content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
